"""Plotting helpers: 3D surface and filled-contour wrappers, and figure saving.
"""
from __future__ import (absolute_import, division,
                        print_function, unicode_literals)

# Names exported by ``from <module> import *``.
__all__ = ['surface', 'contourf', 'savefig']

# Module metadata.
__version__ = '0.1.0'
__date__ = '06/13/2017'
__author__ = 'J.G. Chen'
__email__ = 'cjgls@pku.edu.cn'

import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib import docstring
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


def _split_xyz_args(args):
    """Split positional plotting arguments into ``(ax, X, Y, Z, rest)``.

    Accepted call forms, by number of positional arguments:

    ========  ======================  ====================================
    nargs     Interpretation          Notes
    ========  ======================  ====================================
    1         ``Z``                   *X*, *Y* generated from ``Z.shape``
    2         ``ax, Z``               *X*, *Y* generated from ``Z.shape``
    3         ``X, Y, Z``
    >= 4      ``ax, X, Y, Z, *rest``  *rest* is returned for forwarding
    ========  ======================  ====================================

    When generated, the first returned grid holds column indices of *Z*
    and the second holds row indices.  1D *X* / *Y* are treated as the
    row / column coordinates of *Z* and expanded with
    ``np.meshgrid(Y, X)``.  2D *X* / *Y* are returned unchanged.

    Returns
    -------
    tuple
        ``(ax, X, Y, Z, rest)``. *ax* is ``None`` when not given and
        *rest* is a (possibly empty) tuple.

    Raises
    ------
    TypeError
        If no positional argument is given.
    ValueError
        If a generated-grid *Z* is not 2D, or if *X* and *Y* are not both
        1D or both 2D arrays of the same shape.
    """
    nargs = len(args)
    if nargs == 0:
        raise TypeError('at least one positional argument (Z) is required')

    if nargs <= 2:      # (Z,) or (ax, Z): build index grids from Z's shape
        ax = args[0] if nargs == 2 else None
        Z = np.asanyarray(args[-1])     # asanyarray keeps masked arrays
        if Z.ndim != 2:
            raise ValueError('Z must be a 2D array')
        nx, ny = Z.shape
        X, Y = np.meshgrid(np.arange(ny), np.arange(nx))
        return ax, X, Y, Z, ()

    if nargs == 3:      # (X, Y, Z)
        ax = None
        X, Y, Z = args
        rest = ()
    else:               # (ax, X, Y, Z, *rest)
        ax, X, Y, Z = args[:4]
        rest = args[4:]

    # Accept any array-like (lists, tuples, ...), not only ndarrays.
    X = np.asanyarray(X)
    Y = np.asanyarray(Y)
    if X.ndim != Y.ndim:
        raise ValueError('X and Y must have same dimension')
    if X.ndim == 1:
        # X runs along Z's rows and Y along its columns, hence (Y, X) order.
        X, Y = np.meshgrid(Y, X)
    elif X.ndim != 2 or X.shape != Y.shape:
        raise ValueError('X and Y must be 1D or same shape 2D arrays')
    return ax, X, Y, Z, rest


@docstring.Appender(Axes3D.plot_surface.__doc__)
def surface(*args, **kwargs):
    """ Modified :meth:`mpl_toolkits.mplot3d.Axes3D.plot_surface` function

    By default plot_surface needs three 2D arrays: *X*, *Y*, *Z*.
    In this function, *args include at least one 2D array, which is
    passed on to *Z*. *X* and *Y* are autogenerated according to the
    number of args and the size of each arg.

    =========== ==============================================
    Argument    Description
    =========== ==============================================
    *ax*        Axes3D object (optional; a new 3D subplot is
                added to the current figure if omitted)
    *X*, *Y*    Optional 1D or 2D arrays
    *Z*         Data values as 2D array
    =========== ==============================================

    Keyword arguments consumed here (not forwarded):

    =========== ==============================================
    *xlabel*    x-axis label, default ``'Y'``
    *ylabel*    y-axis label, default ``'X'``
    *zlabel*    z-axis label, not set by default
    *title*     axes title, not set by default
    =========== ==============================================

    Defaults applied unless overridden: ``cmap='jet'``;
    ``rstride=cstride=2`` (only when neither *rcount* nor *ccount* is
    given, because matplotlib rejects mixing strides and counts); line
    width 0 (only when none of *lw*, *linewidth*, *linewidths* is given).

    Other arguments are passed on to
    :meth:`mpl_toolkits.mplot3d.Axes3D.plot_surface`, whose return value
    is returned.

    """
    ax, X, Y, Z, args = _split_xyz_args(args)

    kwargs.setdefault('cmap', 'jet')
    # plot_surface raises ValueError if stride and count arguments are
    # mixed, so only default the strides when no count was requested.
    if not any(key in kwargs for key in ('rcount', 'ccount')):
        kwargs.setdefault('rstride', 2)
        kwargs.setdefault('cstride', 2)
    # Only default the line width if the caller did not set it under any
    # alias; otherwise an injected 'lw' would override the caller's value.
    if not any(key in kwargs for key in ('lw', 'linewidth', 'linewidths')):
        kwargs['lw'] = 0
    xlabel = kwargs.pop('xlabel', 'Y')
    ylabel = kwargs.pop('ylabel', 'X')
    zlabel = kwargs.pop('zlabel', None)
    title = kwargs.pop('title', None)

    if ax is None:
        ax = plt.gcf().add_subplot(111, projection='3d')

    ax.set_xlabel(xlabel, labelpad=20)
    ax.set_ylabel(ylabel, labelpad=20)
    if zlabel:
        ax.set_zlabel(zlabel, labelpad=20)
    if title:
        ax.set_title(title)
    return ax.plot_surface(X, Y, Z, *args, **kwargs)


@docstring.Appender(Axes3D.contourf.__doc__)
def contourf(*args, **kwargs):
    """ Modified :meth:`mpl_toolkits.mplot3d.Axes3D.contourf` function

    By default contourf() needs three 2D arrays: *X*, *Y*, *Z*.
    In this function, *args include at least one 2D array, which is
    passed on to *Z*. *X* and *Y* are autogenerated according to the
    number of args and the size of each arg.

    =========== ==============================================
    Argument    Description
    =========== ==============================================
    *ax*        Axes3D object (optional; a new 3D subplot is
                added to the current figure if omitted)
    *X*, *Y*    Optional 1D or 2D arrays
    *Z*         Data values as 2D array
    =========== ==============================================

    Other arguments are passed on to
    :meth:`mpl_toolkits.mplot3d.Axes3D.contourf`, whose return value is
    returned.

    """
    ax, X, Y, Z, args = _split_xyz_args(args)

    if ax is None:
        ax = plt.gcf().add_subplot(111, projection='3d')

    return ax.contourf(X, Y, Z, *args, **kwargs)


def savefig(fname, handle=None, format=None, **kwargs):
    """ Modified function :func:`matplotlib.pyplot.savefig`

    Save a figure under one or more file names derived from *fname*.

    Parameters
    ----------
    fname: str
        Output file name. If it already ends with a recognised image
        extension (case-insensitive: png, pdf, ps, eps, svg, svgz, jpg,
        jpeg, tif, tiff, pgf, raw, rgba), it is used as is and *format*
        is ignored. Otherwise an extension is appended (see *format*).

    handle: figure object
        The figure object which is to be saved. By default [None],
        the value is given by :func:`matplotlib.pyplot.gcf`.

    format: str
        One of the file extensions supported by the active backend.
        Most backends support png, pdf, ps, eps and svg. A leading dot
        (e.g. ``'.pdf'``) is accepted. By default, the figure will be
        saved in both png and eps format.

    **kwargs
        Passed on to ``handle.savefig`` for every file written.

    Returns
    -------
    list of str
        The file names passed to ``handle.savefig``, in the order written.

    """
    if handle is None:
        handle = plt.gcf()  # by default, save current figure

    # Extensions meaning "fname already names the output file".
    known_exts = ('.png', '.pdf', '.ps', '.eps', '.svg', '.svgz',
                  '.jpeg', '.jpg', '.tif', '.tiff', '.pgf', '.raw', '.rgba')

    if fname.lower().endswith(known_exts):
        targets = [fname]
        message = "Saving figure as '{}' ...".format(fname)
    elif format:
        target = fname + '.' + format.lstrip('.')
        targets = [target]
        message = "Saving figure as '{}' ...".format(target)
    else:   # format=None: save both a raster and a vector version
        targets = [fname + '.png', fname + '.eps']
        message = "Saving figure as '{}(.png|.eps)' ...".format(fname)

    print(message)
    for target in targets:
        handle.savefig(target, **kwargs)
    return targets
